"Functions loading the .pkl version preprocessed data"
from tensorpack import dataflow
from glob import glob
import pickle
import os
import numpy as np


class ArgoversePklLoader(dataflow.RNGDataFlow):
    """Tensorpack dataflow yielding one preprocessed sample per file in a directory.

    Every entry matched by ``<data_path>/*`` is unpickled as a dict, and each of
    its values is indexed with ``[0]`` to drop a leading dimension. Samples whose
    ``lane`` array is longer than ``max_lane_nodes`` along its first axis are
    skipped. For the rest, ``lane`` and ``lane_norm`` are zero-padded to
    ``max_lane_nodes`` along axis 0, and a float32 ``lane_mask`` is added
    (1.0 for real entries, 0.0 for padding).

    NOTE(review): files are read with ``pickle``; only point ``data_path`` at
    trusted data, since unpickling can execute arbitrary code.
    """

    def __init__(self, data_path: str, shuffle: bool=True, max_lane_nodes: int=650):
        """
        Args:
            data_path: directory whose entries are loaded as pickled samples.
            shuffle: if True, shuffle the file order with ``self.rng`` on each pass.
            max_lane_nodes: padded lane length; samples with more lane entries
                are skipped entirely.
        """
        super().__init__()
        self.data_path = data_path
        self.shuffle = shuffle
        self.max_lane_nodes = max_lane_nodes

    def _list_files(self):
        """Return the sorted paths of all entries directly under ``data_path``."""
        return sorted(glob(os.path.join(self.data_path, '*')))

    def __iter__(self):
        """Yield padded sample dicts, one per file that fits in ``max_lane_nodes``."""
        pkl_list = self._list_files()
        if self.shuffle:
            self.rng.shuffle(pkl_list)

        for pkl_path in pkl_list:
            # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
            with open(pkl_path, 'rb') as f:
                data = pickle.load(f)

            # Samples that would not fit in the fixed-size padded arrays are dropped.
            if data['lane'][0].shape[0] > self.max_lane_nodes:
                continue

            # Strip the leading dimension (presumably a batch of one) from every field.
            data = {k: v[0] for k, v in data.items()}

            num_lane_nodes = len(data['lane'])
            lane_mask = np.zeros(self.max_lane_nodes, dtype=np.float32)
            lane_mask[:num_lane_nodes] = 1.0

            data['lane'] = self.expand_particle(data['lane'], self.max_lane_nodes, 0)
            data['lane_norm'] = self.expand_particle(data['lane_norm'], self.max_lane_nodes, 0)
            data['lane_mask'] = lane_mask

            yield data

    def __len__(self):
        """Return the number of entries under ``data_path``.

        NOTE(review): this is an upper bound. ``__iter__`` skips samples whose
        lane length exceeds ``max_lane_nodes``, so fewer items may be yielded.
        """
        return len(self._list_files())

    @classmethod
    def expand_particle(cls, arr, max_num, axis, value_type='int'):
        """Pad ``arr`` along ``axis`` up to length ``max_num``.

        Args:
            arr: array to pad; ``arr.shape[axis]`` must not exceed ``max_num``.
            max_num: target length along ``axis``.
            axis: axis to pad.
            value_type: ``'str'`` pads with unique strings ``'dummy0'``,
                ``'dummy1'``, ...; any other value pads with float64 zeros.

        Returns:
            ``np.concatenate([arr, padding], axis=axis)``. For numeric padding the
            result dtype follows NumPy promotion with float64, so e.g. float32 or
            integer inputs come back as float64.

        Raises:
            ValueError: if ``arr.shape[axis] > max_num``.
        """
        pad = max_num - arr.shape[axis]
        if pad < 0:
            raise ValueError(
                f"cannot pad axis {axis} of length {arr.shape[axis]} down to {max_num}")
        dummy_shape = list(arr.shape)
        dummy_shape[axis] = pad

        if value_type == 'str':
            # np.prod replaces np.product, which was removed in NumPy 2.0.
            num_dummies = int(np.prod(dummy_shape))
            dummy = np.array(['dummy' + str(i) for i in range(num_dummies)]).reshape(dummy_shape)
        else:
            dummy = np.zeros(dummy_shape)
        return np.concatenate([arr, dummy], axis=axis)
    

def read_pkl_data(data_path: str, batch_size: int,
                  shuffle: bool=False, repeat: bool=False, **kwargs):
    """Build a ready-to-iterate, batched dataflow over the samples in ``data_path``.

    Args:
        data_path: directory passed to ``ArgoversePklLoader``.
        batch_size: number of samples per batch.
        shuffle: shuffle the file order on each pass.
        repeat: if True, wrap the loader in ``dataflow.RepeatedData(..., -1)``.
        **kwargs: forwarded to ``ArgoversePklLoader`` (e.g. ``max_lane_nodes``).

    Returns:
        A ``dataflow.BatchData`` built with ``use_list=True`` whose
        ``reset_state()`` has already been called.
    """
    source = ArgoversePklLoader(data_path=data_path, shuffle=shuffle, **kwargs)
    # -1 asks RepeatedData to repeat without a fixed count.
    stream = dataflow.RepeatedData(source, -1) if repeat else source
    batched = dataflow.BatchData(stream, batch_size=batch_size, use_list=True)
    batched.reset_state()
    return batched

